#!/usr/bin/env python3
"""
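Script to regenerate the standard plots (violin plots, ROC curves) for an experiment.
By default it targets the most recently modified experiment folder found under the
search directories; pass --experiment_folder to target a specific one. Re-running it
brings existing results up to date with the latest plot formatting (e.g. fixed x-axis labels).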
"""

import argparse
import glob
import os
from pathlib import Path

from deception_detection.experiment import Experiment, ExperimentConfig
from deception_detection.postprocess import ExperimentPostprocessor
from deception_detection.prompt_scorer import PromptScorerMethod


def find_experiment_folders(search_dirs: list[str] | None = None) -> list[Path]:
    """
    Find all experiment folders that contain cfg.yaml files.
    
    Args:
        search_dirs: List of directories to search in. Defaults to ['results']        
    
    Returns:
        List of paths to experiment folders
    """
    if search_dirs is None:
        search_dirs = ['results']
    
    experiment_folders: list[Path] = []
    
    for search_dir in search_dirs:
        if not os.path.exists(search_dir):
            continue
            
        # Look for cfg.yaml files in subdirectories
        pattern = os.path.join(search_dir, '*', 'cfg.yaml')
        cfg_files = glob.glob(pattern)
        
        for cfg_file in cfg_files:
            exp_folder = Path(cfg_file).parent
            experiment_folders.append(exp_folder)
    
    return sorted(experiment_folders, key=lambda p: p.stat().st_mtime)


def regenerate_violin_plot_only(
    exp_path: Path, scoring_method: PromptScorerMethod = "mean"
) -> bool:
    """
    Regenerate only the violin plot for a single experiment (faster than full postprocessing).

    Args:
        exp_path: Path to the experiment folder
        scoring_method: Token reduction method to use

    Returns:
        True once the violin plot has been generated. Failures are not turned
        into False: any exception raised while loading the experiment, building
        the scorer, or plotting propagates to the caller.

    Raises:
        ValueError: If the experiment config lists no control datasets (the
            first one is passed to the scorer as ``control_dataset_name``).
    """
    print(f"🔄 Regenerating violin plot for: {exp_path.name}")

    # Load experiment (this just reads config and scores.json - very fast)
    cfg = ExperimentConfig.from_path(exp_path)
    # Fail with a descriptive error instead of a bare IndexError on control_data[0].
    if not cfg.control_data:
        raise ValueError(
            f"Experiment {exp_path} has no control datasets configured; "
            "cannot select a control dataset for the scorer."
        )
    experiment = Experiment(cfg)

    # Imported locally, as in the original; presumably to avoid loading the
    # plotting code unless this path is taken.
    from deception_detection.prompt_scorer import create_prompt_scorer
    from deception_detection.results.plot import violin_plot

    scorer = create_prompt_scorer(
        scoring_method,
        experiment,
        fprs=[0.01, 0.001],
        default_fpr=0.01,
        control_dataset_name=cfg.control_data[0],
    )

    # Ensure plots directory exists
    (experiment.cfg.folder / "plots").mkdir(parents=True, exist_ok=True)

    # Generate only violin plot
    violin_plot(experiment, scorer)

    print(f"✅ Successfully regenerated violin plot for: {exp_path.name}")
    return True


def regenerate_plots_for_experiment(
    exp_path: Path,
    scoring_method: PromptScorerMethod = "mean",
    violin_only: bool = False,
) -> bool:
    """
    Regenerate all standard plots for a single experiment.

    Args:
        exp_path: Path to the experiment folder
        scoring_method: Token reduction method to use
        violin_only: If True, only regenerate violin plot (faster)

    Returns:
        True once plotting has finished. Failures are not turned into False:
        any exception raised while loading the experiment or post-processing
        propagates to the caller.

    Raises:
        ValueError: If the experiment config lists no control datasets (the
            first one is passed on as ``control_dataset_name``).
    """
    if violin_only:
        return regenerate_violin_plot_only(exp_path, scoring_method)

    print(f"🔄 Regenerating plots for: {exp_path.name}")

    # Load experiment
    cfg = ExperimentConfig.from_path(exp_path)
    # Fail with a descriptive error instead of a bare IndexError on control_data[0].
    if not cfg.control_data:
        raise ValueError(
            f"Experiment {exp_path} has no control datasets configured; "
            "cannot select a control dataset for postprocessing."
        )
    experiment = Experiment(cfg)

    # Run postprocessing to generate all plots (black-box baselines excluded).
    postprocessor = ExperimentPostprocessor(
        experiment,
        scoring_method=scoring_method,
        include_black_box=False,
        control_dataset_name=cfg.control_data[0],
    )
    postprocessor.run()

    print(f"✅ Successfully regenerated plots for: {exp_path.name}")
    return True
        


def main():
    """Parse command-line arguments and regenerate plots for one experiment.

    The target is ``--experiment_folder`` when given. Otherwise the experiment
    folders discovered under ``--search_dirs`` are listed, and the last one
    returned by ``find_experiment_folders`` (the most recently modified) is
    used. Exits with an argparse usage error if the requested folder is not a
    directory or if no experiment folders are found.
    """
    parser = argparse.ArgumentParser(description="Regenerate all standard plots for experiments")
    parser.add_argument("--search_dirs", type=str, nargs="+", default=["results"],
                        help="Directories to search for experiments (default: results)")
    parser.add_argument("--scoring_method", type=str, default="mean",
                        choices=["mean", "max", "relu", "topk", "final"],
                        help="Token reduction method to use (default: mean)")
    parser.add_argument(
        "--violin_only",
        action="store_true",
        help="Only regenerate violin plots (much faster, skips ROC curves and HTML export)",
    )
    # Explicit experiment folder; when omitted, the most recent one found is used.
    parser.add_argument("--experiment_folder", type=str, default=None,
                        help="Experiment folder to regenerate plots for")

    args = parser.parse_args()

    if args.experiment_folder is not None:
        experiment_folder = Path(args.experiment_folder)
        if not experiment_folder.is_dir():
            parser.error(f"--experiment_folder {experiment_folder} is not a directory")
    else:
        experiment_folders = find_experiment_folders(args.search_dirs)
        # Avoid an IndexError on [-1] below when nothing was found.
        if not experiment_folders:
            parser.error(
                "no experiment folders (containing cfg.yaml) found in: "
                + ", ".join(args.search_dirs)
            )
        print("Available experiment folders:")
        for i, folder in enumerate(experiment_folders, 1):
            print(f"  {i}. {folder}")
        # Folders are sorted by modification time, so the last one is the most recent.
        experiment_folder = experiment_folders[-1]

    print(f"Generating plots for experiment: {experiment_folder}")
    regenerate_plots_for_experiment(
        experiment_folder, args.scoring_method, violin_only=args.violin_only
    )

    


if __name__ == "__main__":
    # Start the CLI only when run as a script, not when this module is imported.
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())